{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Automatic speech recognition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "hide_input": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<iframe width=\"560\" height=\"315\" src=\"https://www.youtube.com/embed/TksaY_FDgnk?rel=0&amp;controls=0&amp;showinfo=0\" frameborder=\"0\" allowfullscreen></iframe>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#@title\n",
    "from IPython.display import HTML\n",
    "\n",
    "HTML('<iframe width=\"560\" height=\"315\" src=\"https://www.youtube.com/embed/TksaY_FDgnk?rel=0&amp;controls=0&amp;showinfo=0\" frameborder=\"0\" allowfullscreen></iframe>')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "自動音声認識 (ASR) は音声信号をテキストに変換し、一連の音声入力をテキスト出力にマッピングします。 Siri や Alexa などの仮想アシスタントは ASR モデルを使用してユーザーを日常的に支援しており、ライブキャプションや会議中のメモ取りなど、他にも便利なユーザー向けアプリケーションが数多くあります。\n",
    "\n",
    "このガイドでは、次の方法を説明します。\n",
    "\n",
    "1. [MInDS-14](https://huggingface.co/datasets/PolyAI/minds14) データセットの [Wav2Vec2](https://huggingface.co/facebook/wav2vec2-base) を微調整して、音声をテキストに書き起こします。\n",
    "2. 微調整したモデルを推論に使用します。\n",
    "\n",
    "<Tip>\n",
    "\n",
    "このタスクと互換性のあるすべてのアーキテクチャとチェックポイントを確認するには、[タスクページ](https://huggingface.co/tasks/automatic-speech-recognition) を確認することをお勧めします。\n",
    "\n",
    "</Tip>\n",
    "\n",
    "始める前に、必要なライブラリがすべてインストールされていることを確認してください。\n",
    "\n",
    "```bash\n",
    "pip install transformers datasets evaluate jiwer\n",
    "```\n",
    "\n",
    "モデルをアップロードしてコミュニティと共有できるように、Hugging Face アカウントにログインすることをお勧めします。プロンプトが表示されたら、トークンを入力してログインします。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import notebook_login\n",
    "\n",
    "notebook_login()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load MInDS-14 dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "まず、🤗 データセット ライブラリから [MInDS-14](https://huggingface.co/datasets/PolyAI/minds14) データセットの小さいサブセットをロードします。これにより、完全なデータセットのトレーニングにさらに時間を費やす前に、実験してすべてが機能することを確認する機会が得られます。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset, Audio\n",
    "\n",
    "minds = load_dataset(\"PolyAI/minds14\", name=\"en-US\", split=\"train[:100]\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`~Dataset.train_test_split` メソッドを使用して、データセットの `train` 分割をトレイン セットとテスト セットに分割します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "minds = minds.train_test_split(test_size=0.2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "次に、データセットを見てみましょう。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['path', 'audio', 'transcription', 'english_transcription', 'intent_class', 'lang_id'],\n",
       "        num_rows: 16\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['path', 'audio', 'transcription', 'english_transcription', 'intent_class', 'lang_id'],\n",
       "        num_rows: 4\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "minds"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "データセットには`lang_id`や`english_transcription`などの多くの有用な情報が含まれていますが、このガイドでは「`audio`」と「`transciption`」に焦点を当てます。 `remove_columns` メソッドを使用して他の列を削除します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "minds = minds.remove_columns([\"english_transcription\", \"intent_class\", \"lang_id\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "もう一度例を見てみましょう。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'audio': {'array': array([-0.00024414,  0.        ,  0.        , ...,  0.00024414,\n",
       "          0.00024414,  0.00024414], dtype=float32),\n",
       "  'path': '/root/.cache/huggingface/datasets/downloads/extracted/f14948e0e84be638dd7943ac36518a4cf3324e8b7aa331c5ab11541518e9368c/en-US~APP_ERROR/602ba9e2963e11ccd901cd4f.wav',\n",
       "  'sampling_rate': 8000},\n",
       " 'path': '/root/.cache/huggingface/datasets/downloads/extracted/f14948e0e84be638dd7943ac36518a4cf3324e8b7aa331c5ab11541518e9368c/en-US~APP_ERROR/602ba9e2963e11ccd901cd4f.wav',\n",
       " 'transcription': \"hi I'm trying to use the banking app on my phone and currently my checking and savings account balance is not refreshing\"}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "minds[\"train\"][0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "次の 2 つのフィールドがあります。\n",
    "\n",
    "- `audio`: 音声ファイルをロードしてリサンプリングするために呼び出す必要がある音声信号の 1 次元の `array`。\n",
    "- `transcription`: ターゲットテキスト。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preprocess"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "次のステップでは、Wav2Vec2 プロセッサをロードしてオーディオ信号を処理します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoProcessor\n",
    "\n",
    "processor = AutoProcessor.from_pretrained(\"facebook/wav2vec2-base\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "MInDS-14 データセットのサンプリング レートは 8000kHz です (この情報は [データセット カード](https://huggingface.co/datasets/PolyAI/minds14) で確認できます)。つまり、データセットを再サンプリングする必要があります。事前トレーニングされた Wav2Vec2 モデルを使用するには、16000kHz に設定します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'audio': {'array': array([-2.38064706e-04, -1.58618059e-04, -5.43987835e-06, ...,\n",
       "          2.78103951e-04,  2.38446111e-04,  1.18740834e-04], dtype=float32),\n",
       "  'path': '/root/.cache/huggingface/datasets/downloads/extracted/f14948e0e84be638dd7943ac36518a4cf3324e8b7aa331c5ab11541518e9368c/en-US~APP_ERROR/602ba9e2963e11ccd901cd4f.wav',\n",
       "  'sampling_rate': 16000},\n",
       " 'path': '/root/.cache/huggingface/datasets/downloads/extracted/f14948e0e84be638dd7943ac36518a4cf3324e8b7aa331c5ab11541518e9368c/en-US~APP_ERROR/602ba9e2963e11ccd901cd4f.wav',\n",
       " 'transcription': \"hi I'm trying to use the banking app on my phone and currently my checking and savings account balance is not refreshing\"}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "minds = minds.cast_column(\"audio\", Audio(sampling_rate=16_000))\n",
    "minds[\"train\"][0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "上の `transcription` でわかるように、テキストには大文字と小文字が混在しています。 Wav2Vec2 トークナイザーは大文字のみでトレーニングされるため、テキストがトークナイザーの語彙と一致することを確認する必要があります。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def uppercase(example):\n",
    "    return {\"transcription\": example[\"transcription\"].upper()}\n",
    "\n",
    "\n",
    "minds = minds.map(uppercase)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "次に、次の前処理関数を作成します。\n",
    "\n",
    "1. `audio`列を呼び出して、オーディオ ファイルをロードしてリサンプリングします。\n",
    "2. オーディオ ファイルから `input_values` を抽出し、プロセッサを使用して `transcription` 列をトークン化します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def prepare_dataset(batch):\n",
    "    audio = batch[\"audio\"]\n",
    "    batch = processor(audio[\"array\"], sampling_rate=audio[\"sampling_rate\"], text=batch[\"transcription\"])\n",
    "    batch[\"input_length\"] = len(batch[\"input_values\"][0])\n",
    "    return batch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "データセット全体に前処理関数を適用するには、🤗 Datasets `map` 関数を使用します。 `num_proc` パラメータを使用してプロセスの数を増やすことで、`map` を高速化できます。 `remove_columns` メソッドを使用して、不要な列を削除します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "encoded_minds = minds.map(prepare_dataset, remove_columns=minds.column_names[\"train\"], num_proc=4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "🤗 Transformers には ASR 用のデータ照合器がないため、`DataCollat​​orWithPadding` を調整してサンプルのバッチを作成する必要があります。また、テキストとラベルが (データセット全体ではなく) バッチ内の最も長い要素の長さに合わせて動的に埋め込まれ、均一な長さになります。 `padding=True` を設定すると、`tokenizer` 関数でテキストを埋め込むことができますが、動的な埋め込みの方が効率的です。\n",
    "\n",
    "他のデータ照合器とは異なり、この特定のデータ照合器は、`input_values`と `labels`」に異なるパディング方法を適用する必要があります。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "from dataclasses import dataclass, field\n",
    "from typing import Any, Dict, List, Optional, Union\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class DataCollatorCTCWithPadding:\n",
    "    processor: AutoProcessor\n",
    "    padding: Union[bool, str] = \"longest\"\n",
    "\n",
    "    def __call__(self, features: list[dict[str, Union[list[int], torch.Tensor]]]) -> dict[str, torch.Tensor]:\n",
    "        # split inputs and labels since they have to be of different lengths and need\n",
    "        # different padding methods\n",
    "        input_features = [{\"input_values\": feature[\"input_values\"][0]} for feature in features]\n",
    "        label_features = [{\"input_ids\": feature[\"labels\"]} for feature in features]\n",
    "\n",
    "        batch = self.processor.pad(input_features, padding=self.padding, return_tensors=\"pt\")\n",
    "\n",
    "        labels_batch = self.processor.pad(labels=label_features, padding=self.padding, return_tensors=\"pt\")\n",
    "\n",
    "        # replace padding with -100 to ignore loss correctly\n",
    "        labels = labels_batch[\"input_ids\"].masked_fill(labels_batch.attention_mask.ne(1), -100)\n",
    "\n",
    "        batch[\"labels\"] = labels\n",
    "\n",
    "        return batch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "次に、`DataCollat​​orForCTCWithPadding` をインスタンス化します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_collator = DataCollatorCTCWithPadding(processor=processor, padding=\"longest\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "トレーニング中にメトリクスを含めると、多くの場合、モデルのパフォーマンスを評価するのに役立ちます。 🤗 [Evaluate](https://huggingface.co/docs/evaluate/index) ライブラリを使用して、評価メソッドをすばやくロードできます。このタスクでは、[単語エラー率](https://huggingface.co/spaces/evaluate-metric/wer) (WER) メトリクスを読み込みます (🤗 Evaluate [クイック ツアー](https://huggingface.co/docs/evaluate/a_quick_tour) を参照して、メトリクスをロードして計算する方法の詳細を確認してください)。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import evaluate\n",
    "\n",
    "wer = evaluate.load(\"wer\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "次に、予測とラベルを `compute` に渡して WER を計算する関数を作成します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def compute_metrics(pred):\n",
    "    pred_logits = pred.predictions\n",
    "    pred_ids = np.argmax(pred_logits, axis=-1)\n",
    "\n",
    "    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id\n",
    "\n",
    "    pred_str = processor.batch_decode(pred_ids)\n",
    "    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)\n",
    "\n",
    "    wer = wer.compute(predictions=pred_str, references=label_str)\n",
    "\n",
    "    return {\"wer\": wer}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "これで`compute_metrics`関数の準備が整いました。トレーニングをセットアップするときにこの関数に戻ります。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<Tip>\n",
    "\n",
    "[Trainer](https://huggingface.co/docs/transformers/main/ja/main_classes/trainer#transformers.Trainer) を使用したモデルの微調整に慣れていない場合は、[ここ](https://huggingface.co/docs/transformers/main/ja/tasks/../training#train-with-pytorch-trainer) の基本的なチュートリアルをご覧ください。\n",
    "\n",
    "</Tip>\n",
    "\n",
    "これでモデルのトレーニングを開始する準備が整いました。 [AutoModelForCTC](https://huggingface.co/docs/transformers/main/ja/model_doc/auto#transformers.AutoModelForCTC) で Wav2Vec2 をロードします。 `ctc_loss_reduction` パラメータで適用する削減を指定します。多くの場合、デフォルトの合計ではなく平均を使用する方が適切です。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForCTC, TrainingArguments, Trainer\n",
    "\n",
    "model = AutoModelForCTC.from_pretrained(\n",
    "    \"facebook/wav2vec2-base\",\n",
    "    ctc_loss_reduction=\"mean\",\n",
    "    pad_token_id=processor.tokenizer.pad_token_id,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "この時点で残っている手順は次の 3 つだけです。\n",
    "\n",
    "1. [TrainingArguments](https://huggingface.co/docs/transformers/main/ja/main_classes/trainer#transformers.TrainingArguments) でトレーニング ハイパーパラメータを定義します。唯一の必須パラメータは、モデルの保存場所を指定する `output_dir` です。 `push_to_hub=True`を設定して、このモデルをハブにプッシュします (モデルをアップロードするには、Hugging Face にサインインする必要があります)。各エポックの終了時に、`トレーナー` は WER を評価し、トレーニング チェックポイントを保存します。\n",
    "2. トレーニング引数を、モデル、データセット、トークナイザー、データ照合器、および `compute_metrics` 関数とともに [Trainer](https://huggingface.co/docs/transformers/main/ja/main_classes/trainer#transformers.Trainer) に渡します。\n",
    "3. [train()](https://huggingface.co/docs/transformers/main/ja/main_classes/trainer#transformers.Trainer.train) を呼び出してモデルを微調整します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "training_args = TrainingArguments(\n",
    "    output_dir=\"my_awesome_asr_mind_model\",\n",
    "    per_device_train_batch_size=8,\n",
    "    gradient_accumulation_steps=2,\n",
    "    learning_rate=1e-5,\n",
    "    warmup_steps=500,\n",
    "    max_steps=2000,\n",
    "    gradient_checkpointing=True,\n",
    "    fp16=True,\n",
    "    group_by_length=True,\n",
    "    eval_strategy=\"steps\",\n",
    "    per_device_eval_batch_size=8,\n",
    "    save_steps=1000,\n",
    "    eval_steps=1000,\n",
    "    logging_steps=25,\n",
    "    load_best_model_at_end=True,\n",
    "    metric_for_best_model=\"wer\",\n",
    "    greater_is_better=False,\n",
    "    push_to_hub=True,\n",
    ")\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=encoded_minds[\"train\"],\n",
    "    eval_dataset=encoded_minds[\"test\"],\n",
    "    processing_class=processor,\n",
    "    data_collator=data_collator,\n",
    "    compute_metrics=compute_metrics,\n",
    ")\n",
    "\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "トレーニングが完了したら、 [push_to_hub()](https://huggingface.co/docs/transformers/main/ja/main_classes/trainer#transformers.Trainer.push_to_hub) メソッドを使用してモデルをハブに共有し、誰もがモデルを使用できるようにします。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.push_to_hub()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<Tip>\n",
    "\n",
    "自動音声認識用にモデルを微調整する方法のより詳細な例については、英語 ASR および英語のこのブログ [投稿](https://huggingface.co/blog/fine-tune-wav2vec2-english) を参照してください。多言語 ASR については、この [投稿](https://huggingface.co/blog/fine-tune-xlsr-wav2vec2) を参照してください。\n",
    "\n",
    "</Tip>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "モデルを微調整したので、それを推論に使用できるようになりました。\n",
    "\n",
    "推論を実行したい音声ファイルをロードします。必要に応じて、オーディオ ファイルのサンプリング レートをモデルのサンプリング レートと一致するようにリサンプリングすることを忘れないでください。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset, Audio\n",
    "\n",
    "dataset = load_dataset(\"PolyAI/minds14\", \"en-US\", split=\"train\")\n",
    "dataset = dataset.cast_column(\"audio\", Audio(sampling_rate=16000))\n",
    "sampling_rate = dataset.features[\"audio\"].sampling_rate\n",
    "audio_file = dataset[0][\"audio\"][\"path\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "推論用に微調整されたモデルを試す最も簡単な方法は、それを [pipeline()](https://huggingface.co/docs/transformers/main/ja/main_classes/pipelines#transformers.pipeline) で使用することです。モデルを使用して自動音声認識用の`pipeline`をインスタンス化し、オーディオ ファイルをそれに渡します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'text': 'I WOUD LIKE O SET UP JOINT ACOUNT WTH Y PARTNER'}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from transformers import pipeline\n",
    "\n",
    "transcriber = pipeline(\"automatic-speech-recognition\", model=\"stevhliu/my_awesome_asr_minds_model\")\n",
    "transcriber(audio_file)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<Tip>\n",
    "\n",
    "転写はまあまあですが、もっと良くなる可能性があります。さらに良い結果を得るには、より多くの例でモデルを微調整してみてください。\n",
    "\n",
    "</Tip>\n",
    "\n",
    "必要に応じて、「パイプライン」の結果を手動で複製することもできます。\n",
    "\n",
    "\n",
    "プロセッサをロードしてオーディオ ファイルと文字起こしを前処理し、`input`を PyTorch テンソルとして返します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoProcessor\n",
    "\n",
    "processor = AutoProcessor.from_pretrained(\"stevhliu/my_awesome_asr_mind_model\")\n",
    "inputs = processor(dataset[0][\"audio\"][\"array\"], sampling_rate=sampling_rate, return_tensors=\"pt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Pass your inputs to the model and return the logits:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForCTC\n",
    "\n",
    "model = AutoModelForCTC.from_pretrained(\"stevhliu/my_awesome_asr_mind_model\")\n",
    "with torch.no_grad():\n",
    "    logits = model(**inputs).logits"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "最も高い確率で予測された `input_ids` を取得し、プロセッサを使用して予測された `input_ids` をデコードしてテキストに戻します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['I WOUL LIKE O SET UP JOINT ACOUNT WTH Y PARTNER']"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "predicted_ids = torch.argmax(logits, dim=-1)\n",
    "transcription = processor.batch_decode(predicted_ids)\n",
    "transcription"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 4
}
